{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "g_nWetWWd_ns"
   },
   "source": [
    "##### Copyright 2019 The TensorFlow Authors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2023-11-08T00:03:56.067378Z",
     "iopub.status.busy": "2023-11-08T00:03:56.067122Z",
     "iopub.status.idle": "2023-11-08T00:03:56.071576Z",
     "shell.execute_reply": "2023-11-08T00:03:56.070855Z"
    },
    "id": "2pHVBk_seED1"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2023-11-08T00:03:56.074747Z",
     "iopub.status.busy": "2023-11-08T00:03:56.074510Z",
     "iopub.status.idle": "2023-11-08T00:03:56.078363Z",
     "shell.execute_reply": "2023-11-08T00:03:56.077726Z"
    },
    "id": "N_fMsQ-N8I7j"
   },
   "outputs": [],
   "source": [
    "#@title MIT License\n",
    "#\n",
    "# Copyright (c) 2017 François Chollet\n",
    "#\n",
    "# Permission is hereby granted, free of charge, to any person obtaining a\n",
    "# copy of this software and associated documentation files (the \"Software\"),\n",
    "# to deal in the Software without restriction, including without limitation\n",
    "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n",
    "# and/or sell copies of the Software, and to permit persons to whom the\n",
    "# Software is furnished to do so, subject to the following conditions:\n",
    "#\n",
    "# The above copyright notice and this permission notice shall be included in\n",
    "# all copies or substantial portions of the Software.\n",
    "#\n",
    "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
    "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
    "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n",
    "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
    "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n",
    "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n",
    "# DEALINGS IN THE SOFTWARE."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pZJ3uY9O17VN"
   },
   "source": [
    "# 保存和恢复模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mBdde4YJeJKF"
   },
   "source": [
    "可以在训练期间和之后保存模型进度。这意味着模型可以从停止的地方恢复，避免长时间的训练。此外，保存还意味着您可以分享您的模型，其他人可以重现您的工作。在发布研究模型和技术时，大多数机器学习从业者会分享：\n",
    "\n",
    "- 用于创建模型的代码\n",
    "- 模型的训练权重或形参\n",
    "\n",
    "共享数据有助于其他人了解模型的工作原理，并使用新数据自行尝试。\n",
    "\n",
    "小心：TensorFlow 模型是代码，对于不受信任的代码，一定要小心。请参阅 [安全使用 TensorFlow](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) 以了解详情。\n",
    "\n",
    "### 选项\n",
    "\n",
    "根据您使用的 API，可以通过不同的方式保存 TensorFlow 模型。本指南使用 [tf.keras](https://tensorflow.google.cn/guide/keras) – 一种用于在 TensorFlow 中构建和训练模型的高级 API。建议使用本教程中使用的新的高级 `.keras` 格式来保存 Keras 对象，因为它提供了强大、高效的基于名称的保存，通常比低级或旧版格式更容易调试。如需更高级的保存或序列化工作流，尤其是那些涉及自定义对象的工作流，请参阅[保存和加载 Keras 模型指南](https://tensorflow.google.cn/guide/keras/save_and_serialize)。对于其他方式，请参阅[使用 SavedModel 格式指南](../../guide/saved_model.ipynb)。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xCUREq7WXgvg"
   },
   "source": [
    "## 配置\n",
    "\n",
    "### 安装并导入"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7l0MiTOrXtNv"
   },
   "source": [
    "安装并导入Tensorflow和依赖项："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:03:56.082437Z",
     "iopub.status.busy": "2023-11-08T00:03:56.081875Z",
     "iopub.status.idle": "2023-11-08T00:03:58.152894Z",
     "shell.execute_reply": "2023-11-08T00:03:58.151889Z"
    },
    "id": "RzIOVSdnMYyO"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: pyyaml in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (6.0.1)\r\n",
      "Requirement already satisfied: h5py in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (3.10.0)\r\n",
      "Requirement already satisfied: numpy>=1.17.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from h5py) (1.26.1)\r\n"
     ]
    }
   ],
   "source": [
    "!pip install pyyaml h5py  # Required to save models in HDF5 format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:03:58.157428Z",
     "iopub.status.busy": "2023-11-08T00:03:58.157126Z",
     "iopub.status.idle": "2023-11-08T00:04:00.748993Z",
     "shell.execute_reply": "2023-11-08T00:04:00.748205Z"
    },
    "id": "7Nm7Tyb-gRt-"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-11-08 00:03:58.631423: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2023-11-08 00:03:58.631475: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2023-11-08 00:03:58.633260: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.15.0-rc1\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "\n",
    "print(tf.version.VERSION)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SbGsznErXWt6"
   },
   "source": [
    "### 获取示例数据集\n",
    "\n",
    "为了演示如何保存和加载权重，您将使用 [MNIST 数据集](http://yann.lecun.com/exdb/mnist/)。为了加快运行速度，请使用前 1000 个样本："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:00.753152Z",
     "iopub.status.busy": "2023-11-08T00:04:00.752536Z",
     "iopub.status.idle": "2023-11-08T00:04:01.073866Z",
     "shell.execute_reply": "2023-11-08T00:04:01.072867Z"
    },
    "id": "9rGfFwE9XVwz"
   },
   "outputs": [],
   "source": [
    "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n",
    "\n",
    "train_labels = train_labels[:1000]\n",
    "test_labels = test_labels[:1000]\n",
    "\n",
    "train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0\n",
    "test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "anG3iVoXyZGI"
   },
   "source": [
    "### 定义模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wynsOBfby0Pa"
   },
   "source": [
    "首先构建一个简单的序列（sequential）模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:01.078897Z",
     "iopub.status.busy": "2023-11-08T00:04:01.078287Z",
     "iopub.status.idle": "2023-11-08T00:04:03.456375Z",
     "shell.execute_reply": "2023-11-08T00:04:03.455574Z"
    },
    "id": "0HZbJIjxyX1S"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " Layer (type)                Output Shape              Param #   \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=================================================================\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " dense (Dense)               (None, 512)               401920    \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                                 \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " dropout (Dropout)           (None, 512)               0         \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                                 \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " dense_1 (Dense)             (None, 10)                5130      \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                                 \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=================================================================\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total params: 407050 (1.55 MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trainable params: 407050 (1.55 MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Non-trainable params: 0 (0.00 Byte)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Define a simple sequential model\n",
    "def create_model():\n",
    "  model = tf.keras.Sequential([\n",
    "    keras.layers.Dense(512, activation='relu', input_shape=(784,)),\n",
    "    keras.layers.Dropout(0.2),\n",
    "    keras.layers.Dense(10)\n",
    "  ])\n",
    "\n",
    "  model.compile(optimizer='adam',\n",
    "                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "                metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])\n",
    "\n",
    "  return model\n",
    "\n",
    "# Create a basic model instance\n",
    "model = create_model()\n",
    "\n",
    "# Display the model's architecture\n",
    "model.summary()"
   ]
  },
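  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter counts reported by `model.summary()` can be checked by hand. As a minimal sketch (the helper name `dense_params` is hypothetical and introduced only for illustration), a `Dense` layer with a bias holds `inputs * units` kernel weights plus `units` bias terms, while `Dropout` has no parameters:\n",
    "\n",
    "```python\n",
    "# Hypothetical helper, for illustration only: parameter count of a Dense layer with bias.\n",
    "def dense_params(inputs, units):\n",
    "    return inputs * units + units  # kernel weights + bias terms\n",
    "\n",
    "hidden = dense_params(784, 512)  # first Dense layer: 784 inputs -> 512 units\n",
    "output = dense_params(512, 10)   # second Dense layer: 512 inputs -> 10 units\n",
    "total = hidden + output          # Dropout contributes no parameters\n",
    "\n",
    "print(hidden, output, total)\n",
    "```\n",
    "\n",
    "The three values correspond to the `dense`, `dense_1`, and `Total params` rows of the summary."
   ]
  },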
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "soDE0W_KH8rG"
   },
   "source": [
    "## 在训练期间保存模型（以 checkpoints 形式保存）"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mRyd5qQQIXZm"
   },
   "source": [
    "您可以使用经过训练的模型而无需重新训练，或者在训练过程中断的情况下从离开处继续训练。`tf.keras.callbacks.ModelCheckpoint` 回调允许您在训练*期间*和*结束*时持续保存模型。\n",
    "\n",
    "### Checkpoint 回调用法\n",
    "\n",
    "创建一个只在训练期间保存权重的 `tf.keras.callbacks.ModelCheckpoint` 回调："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:03.464469Z",
     "iopub.status.busy": "2023-11-08T00:04:03.463687Z",
     "iopub.status.idle": "2023-11-08T00:04:07.839525Z",
     "shell.execute_reply": "2023-11-08T00:04:07.838585Z"
    },
    "id": "IFPuhwntH8VH"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "I0000 00:00:1699401845.104377  594618 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.\n"
     ]
    },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Epoch 1: saving model to training_1/cp.ckpt\n",
       "32/32 [==============================] - 2s 13ms/step - loss: 1.1396 - sparse_categorical_accuracy: 0.6780 - val_loss: 0.6977 - val_sparse_categorical_accuracy: 0.7850\n",
       "Epoch 2/10\n",
       "Epoch 2: saving model to training_1/cp.ckpt\n",
       "32/32 [==============================] - 0s 7ms/step - loss: 0.4015 - sparse_categorical_accuracy: 0.8870 - val_loss: 0.5432 - val_sparse_categorical_accuracy: 0.8300\n",
       "Epoch 3/10\n",
       "Epoch 3: saving model to training_1/cp.ckpt\n",
       "32/32 [==============================] - 0s 7ms/step - loss: 0.2896 - sparse_categorical_accuracy: 0.9240 - val_loss: 0.4702 - val_sparse_categorical_accuracy: 0.8460\n",
       "Epoch 4/10\n",
       "Epoch 4: saving model to training_1/cp.ckpt\n",
       "32/32 [==============================] - 0s 6ms/step - loss: 0.1912 - sparse_categorical_accuracy: 0.9540 - val_loss: 0.4457 - val_sparse_categorical_accuracy: 0.8550\n",
       "Epoch 5/10\n",
       "Epoch 5: saving model to training_1/cp.ckpt\n",
       "32/32 [==============================] - 0s 7ms/step - loss: 0.1541 - sparse_categorical_accuracy: 0.9670 - val_loss: 0.4262 - val_sparse_categorical_accuracy: 0.8560\n",
       "Epoch 6/10\n",
       "Epoch 6: saving model to training_1/cp.ckpt\n",
       "32/32 [==============================] - 0s 7ms/step - loss: 0.1096 - sparse_categorical_accuracy: 0.9830 - val_loss: 0.4280 - val_sparse_categorical_accuracy: 0.8540\n",
       "Epoch 7/10\n",
       "Epoch 7: saving model to training_1/cp.ckpt\n",
       "32/32 [==============================] - 0s 7ms/step - loss: 0.0869 - sparse_categorical_accuracy: 0.9840 - val_loss: 0.4128 - val_sparse_categorical_accuracy: 0.8680\n",
       "Epoch 8/10\n",
       "Epoch 8: saving model to training_1/cp.ckpt\n",
       "32/32 [==============================] - 0s 7ms/step - loss: 0.0647 - sparse_categorical_accuracy: 0.9910 - val_loss: 0.4074 - val_sparse_categorical_accuracy: 0.8620\n",
       "Epoch 9/10\n",
       "Epoch 9: saving model to training_1/cp.ckpt\n",
       "32/32 [==============================] - 0s 6ms/step - loss: 0.0511 - sparse_categorical_accuracy: 0.9980 - val_loss: 0.4087 - val_sparse_categorical_accuracy: 0.8700\n",
       "Epoch 10/10\n",
       "Epoch 10: saving model to training_1/cp.ckpt\n",
       "32/32 [==============================] - 0s 7ms/step - loss: 0.0401 - sparse_categorical_accuracy: 0.9990 - val_loss: 0.4369 - val_sparse_categorical_accuracy: 0.8560\n"
      ]
     },
    {
     "data": {
      "text/plain": [
       "<keras.src.callbacks.History at 0x7fb7a00a59d0>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "checkpoint_path = \"training_1/cp.ckpt\"\n",
    "checkpoint_dir = os.path.dirname(checkpoint_path)\n",
    "\n",
    "# Create a callback that saves the model's weights\n",
    "cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,\n",
    "                                                 save_weights_only=True,\n",
    "                                                 verbose=1)\n",
    "\n",
    "# Train the model with the new callback\n",
    "model.fit(train_images, \n",
    "          train_labels,  \n",
    "          epochs=10,\n",
    "          validation_data=(test_images, test_labels),\n",
    "          callbacks=[cp_callback])  # Pass callback to training\n",
    "\n",
    "# This may generate warnings related to saving the state of the optimizer.\n",
    "# These warnings (and similar warnings throughout this notebook)\n",
    "# are in place to discourage outdated usage, and can be ignored."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rlM-sgyJO084"
   },
   "source": [
    "这将创建一个 TensorFlow checkpoint 文件集合，这些文件在每个 epoch 结束时更新："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:07.843754Z",
     "iopub.status.busy": "2023-11-08T00:04:07.843045Z",
     "iopub.status.idle": "2023-11-08T00:04:07.848862Z",
     "shell.execute_reply": "2023-11-08T00:04:07.847946Z"
    },
    "id": "gXG5FVKFOVQ3"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['cp.ckpt.index', 'cp.ckpt.data-00000-of-00001', 'checkpoint']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "os.listdir(checkpoint_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wlRN_f56Pqa9"
   },
   "source": [
    "只要两个模型共享相同的架构，您就可以在它们之间共享权重。因此，当从仅权重恢复模型时，创建一个与原始模型具有相同架构的模型，然后设置其权重。\n",
    "\n",
    "现在，重新构建一个未经训练的全新模型并基于测试集对其进行评估。未经训练的模型将以机会水平执行（约 10% 的准确率）："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:07.852691Z",
     "iopub.status.busy": "2023-11-08T00:04:07.852016Z",
     "iopub.status.idle": "2023-11-08T00:04:08.130001Z",
     "shell.execute_reply": "2023-11-08T00:04:08.129112Z"
    },
    "id": "Fp5gbuiaPqCT"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32/32 - 0s - loss: 2.3531 - sparse_categorical_accuracy: 0.0720 - 188ms/epoch - 6ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Untrained model, accuracy:  7.20%\n"
     ]
    }
   ],
   "source": [
    "# Create a basic model instance\n",
    "model = create_model()\n",
    "\n",
    "# Evaluate the model\n",
    "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n",
    "print(\"Untrained model, accuracy: {:5.2f}%\".format(100 * acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1DTKpZssRSo3"
   },
   "source": [
    "然后从 checkpoint 加载权重并重新评估："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:08.134014Z",
     "iopub.status.busy": "2023-11-08T00:04:08.133346Z",
     "iopub.status.idle": "2023-11-08T00:04:08.291014Z",
     "shell.execute_reply": "2023-11-08T00:04:08.290129Z"
    },
    "id": "2IZxbwiRRSD2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32/32 - 0s - loss: 0.4369 - sparse_categorical_accuracy: 0.8560 - 90ms/epoch - 3ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Restored model, accuracy: 85.60%\n"
     ]
    }
   ],
   "source": [
    "# Loads the weights\n",
    "model.load_weights(checkpoint_path)\n",
    "\n",
    "# Re-evaluate the model\n",
    "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n",
    "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bpAbKkAyVPV8"
   },
   "source": [
    "### checkpoint 回调选项\n",
    "\n",
    "回调提供了几个选项，为 checkpoint 提供唯一名称并调整 checkpoint 频率。\n",
    "\n",
    "训练一个新模型，每五个 epochs 保存一次唯一命名的 checkpoint ："
   ]
  },
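  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both options can be previewed without TensorFlow. The sketch below is a plain-Python illustration that assumes the 1000 training examples and the batch size of 32 used in this tutorial. It shows how a save interval counted in batches is derived for five epochs, and how a `str.format` pattern expands into unique checkpoint names. The variable names are illustrative only.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "num_examples = 1000  # size of the training subset used in this tutorial\n",
    "batch_size = 32      # assumed batch size\n",
    "\n",
    "# Batches per epoch, rounded up so the final partial batch is counted\n",
    "batches_per_epoch = math.ceil(num_examples / batch_size)\n",
    "\n",
    "# An integer save interval is counted in batches, so five epochs is 5 * batches_per_epoch\n",
    "save_every_n_batches = 5 * batches_per_epoch\n",
    "\n",
    "# The {epoch:04d} field zero-pads the epoch number to four digits\n",
    "pattern = 'training_2/cp-{epoch:04d}.ckpt'\n",
    "names = [pattern.format(epoch=e) for e in (0, 5, 10)]\n",
    "\n",
    "print(batches_per_epoch, save_every_n_batches)\n",
    "print(names)\n",
    "```\n",
    "\n",
    "Zero-padding keeps the checkpoint names in epoch order when they are sorted alphabetically."
   ]
  },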
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:08.294950Z",
     "iopub.status.busy": "2023-11-08T00:04:08.294407Z",
     "iopub.status.idle": "2023-11-08T00:04:17.842195Z",
     "shell.execute_reply": "2023-11-08T00:04:17.841201Z"
    },
    "id": "mQF_dlgIVOvq"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 5: saving model to training_2/cp-0005.ckpt\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 10: saving model to training_2/cp-0010.ckpt\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 15: saving model to training_2/cp-0015.ckpt\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 20: saving model to training_2/cp-0020.ckpt\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 25: saving model to training_2/cp-0025.ckpt\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 30: saving model to training_2/cp-0030.ckpt\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 35: saving model to training_2/cp-0035.ckpt\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 40: saving model to training_2/cp-0040.ckpt\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 45: saving model to training_2/cp-0045.ckpt\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 50: saving model to training_2/cp-0050.ckpt\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.src.callbacks.History at 0x7fb7203975b0>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Include the epoch in the file name (uses `str.format`)\n",
    "checkpoint_path = \"training_2/cp-{epoch:04d}.ckpt\"\n",
    "checkpoint_dir = os.path.dirname(checkpoint_path)\n",
    "\n",
    "batch_size = 32\n",
    "\n",
    "# Calculate the number of batches per epoch\n",
    "import math\n",
    "n_batches = len(train_images) / batch_size\n",
    "n_batches = math.ceil(n_batches)    # round up the number of batches to the nearest whole integer\n",
    "\n",
    "# Create a callback that saves the model's weights every 5 epochs\n",
    "cp_callback = tf.keras.callbacks.ModelCheckpoint(\n",
    "    filepath=checkpoint_path, \n",
    "    verbose=1, \n",
    "    save_weights_only=True,\n",
    "    save_freq=5*n_batches)\n",
    "\n",
    "# Create a new model instance\n",
    "model = create_model()\n",
    "\n",
    "# Save the weights using the `checkpoint_path` format\n",
    "model.save_weights(checkpoint_path.format(epoch=0))\n",
    "\n",
    "# Train the model with the new callback\n",
    "model.fit(train_images, \n",
    "          train_labels,\n",
    "          epochs=50, \n",
    "          batch_size=batch_size, \n",
    "          callbacks=[cp_callback],\n",
    "          validation_data=(test_images, test_labels),\n",
    "          verbose=0)"
   ]
  },
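  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An integer `save_freq` counts batches rather than epochs, so the five-epoch interval above depends on both the dataset size and the batch size. The following is a minimal sketch of that arithmetic and of how the `filepath` template expands. It assumes 1,000 training examples, a hypothetical value standing in for `len(train_images)`, and it does not call TensorFlow at all.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Assumed values for illustration; in the notebook they come from the data.\n",
    "num_examples = 1000  # hypothetical stand-in for len(train_images)\n",
    "batch_size = 32\n",
    "\n",
    "# Batches per epoch, rounded up so a final partial batch is counted.\n",
    "n_batches = math.ceil(num_examples / batch_size)\n",
    "\n",
    "# Saving every 5 epochs means saving every 5 * n_batches batches.\n",
    "save_freq = 5 * n_batches\n",
    "\n",
    "# The checkpoint path is a str.format template with an epoch field.\n",
    "checkpoint_path = 'training_2/cp-{epoch:04d}.ckpt'\n",
    "first_saved = checkpoint_path.format(epoch=5)\n",
    "\n",
    "print(n_batches, save_freq, first_saved)\n",
    "```\n",
    "\n",
    "If the batch size or the number of examples changes, `save_freq` has to be recomputed, otherwise checkpoints will no longer line up with epoch boundaries."
   ]
  },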
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1zFrKTjjavWI"
   },
   "source": [
    "Now, review the resulting checkpoints and choose the latest one:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:17.846493Z",
     "iopub.status.busy": "2023-11-08T00:04:17.845856Z",
     "iopub.status.idle": "2023-11-08T00:04:17.852002Z",
     "shell.execute_reply": "2023-11-08T00:04:17.851127Z"
    },
    "id": "p64q3-V4sXt0"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['cp-0050.ckpt.index',\n",
       " 'cp-0045.ckpt.data-00000-of-00001',\n",
       " 'cp-0005.ckpt.data-00000-of-00001',\n",
       " 'cp-0000.ckpt.index',\n",
       " 'cp-0000.ckpt.data-00000-of-00001',\n",
       " 'cp-0045.ckpt.index',\n",
       " 'cp-0035.ckpt.index',\n",
       " 'cp-0015.ckpt.data-00000-of-00001',\n",
       " 'cp-0025.ckpt.index',\n",
       " 'cp-0040.ckpt.data-00000-of-00001',\n",
       " 'cp-0040.ckpt.index',\n",
       " 'cp-0005.ckpt.index',\n",
       " 'cp-0010.ckpt.index',\n",
       " 'cp-0030.ckpt.index',\n",
       " 'cp-0015.ckpt.index',\n",
       " 'cp-0035.ckpt.data-00000-of-00001',\n",
       " 'cp-0020.ckpt.index',\n",
       " 'cp-0030.ckpt.data-00000-of-00001',\n",
       " 'cp-0020.ckpt.data-00000-of-00001',\n",
       " 'cp-0010.ckpt.data-00000-of-00001',\n",
       " 'cp-0025.ckpt.data-00000-of-00001',\n",
       " 'checkpoint',\n",
       " 'cp-0050.ckpt.data-00000-of-00001']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "os.listdir(checkpoint_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:17.855542Z",
     "iopub.status.busy": "2023-11-08T00:04:17.854890Z",
     "iopub.status.idle": "2023-11-08T00:04:17.861011Z",
     "shell.execute_reply": "2023-11-08T00:04:17.860225Z"
    },
    "id": "1AN_fnuyR41H"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'training_2/cp-0050.ckpt'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "latest = tf.train.latest_checkpoint(checkpoint_dir)\n",
    "latest"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Zk2ciGbKg561"
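   },
   "source": [
    "Note: The default TensorFlow format is described as saving only the 5 most recent checkpoints. In the run above, however, the directory listing still contains every checkpoint from `cp-0000` to `cp-0050`.\n",
    "\n",
    "To test, reset the model and load the latest checkpoint:"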
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:17.864891Z",
     "iopub.status.busy": "2023-11-08T00:04:17.864247Z",
     "iopub.status.idle": "2023-11-08T00:04:18.150300Z",
     "shell.execute_reply": "2023-11-08T00:04:18.149558Z"
    },
    "id": "3M04jyK-H3QK"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32/32 - 0s - loss: 0.4731 - sparse_categorical_accuracy: 0.8800 - 184ms/epoch - 6ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Restored model, accuracy: 88.00%\n"
     ]
    }
   ],
   "source": [
    "# Create a new model instance\n",
    "model = create_model()\n",
    "\n",
    "# Load the previously saved weights\n",
    "model.load_weights(latest)\n",
    "\n",
    "# Re-evaluate the model\n",
    "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n",
    "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "c2OxsJOTHxia"
   },
   "source": [
    "## What are these files?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JtdYhvWnH2ib"
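   },
   "source": [
    "The above code stores the weights in a collection of [checkpoint](../../guide/checkpoint.ipynb)-formatted files that contain only the trained weights in a binary format. Checkpoints contain:\n",
    "\n",
    "- One or more shards that contain your model's weights.\n",
    "- An index file that indicates which weights are stored in which shard.\n",
    "\n",
    "If you are training a model on a single machine, you will have one shard with the suffix: `.data-00000-of-00001`"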
   ]
  },
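  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the file layout concrete, here is a small sketch that groups file names by checkpoint prefix, pairing each index file with its data shards. The `group_by_checkpoint` helper is hypothetical (it is not a TensorFlow API), and the file names are a hand-copied subset of the directory listing shown earlier.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "# A few file names copied from the earlier directory listing.\n",
    "files = [\n",
    "    'cp-0050.ckpt.index',\n",
    "    'cp-0050.ckpt.data-00000-of-00001',\n",
    "    'cp-0045.ckpt.index',\n",
    "    'cp-0045.ckpt.data-00000-of-00001',\n",
    "    'checkpoint',\n",
    "]\n",
    "\n",
    "def group_by_checkpoint(names):\n",
    "    # Hypothetical helper: map each prefix to its index file and shards.\n",
    "    groups = defaultdict(lambda: {'index': None, 'shards': []})\n",
    "    for name in names:\n",
    "        if name.endswith('.index'):\n",
    "            groups[name[:-len('.index')]]['index'] = name\n",
    "        elif '.data-' in name:\n",
    "            groups[name.split('.data-')[0]]['shards'].append(name)\n",
    "        # Other files, such as 'checkpoint', are skipped.\n",
    "    return dict(groups)\n",
    "\n",
    "groups = group_by_checkpoint(files)\n",
    "for prefix in sorted(groups):\n",
    "    print(prefix, groups[prefix])\n",
    "```\n",
    "\n",
    "The prefix (for example `cp-0050.ckpt`) is the value passed to `load_weights`, not the name of any single file on disk."
   ]
  },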
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "S_FA-ZvxuXQV"
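   },
   "source": [
    "## Manually save weights\n",
    "\n",
    "To save weights manually, use `tf.keras.Model.save_weights`. By default, `tf.keras` (and the `Model.save_weights` method in particular) uses the TensorFlow [Checkpoint](../../guide/checkpoint.ipynb) format with a `.ckpt` extension. To save in the HDF5 format with a `.h5` extension, refer to the [Save and load models](https://tensorflow.google.cn/guide/keras/save_and_serialize) guide."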
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:18.154276Z",
     "iopub.status.busy": "2023-11-08T00:04:18.153999Z",
     "iopub.status.idle": "2023-11-08T00:04:18.449398Z",
     "shell.execute_reply": "2023-11-08T00:04:18.448376Z"
    },
    "id": "R7W5plyZ-u9X"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32/32 - 0s - loss: 0.4731 - sparse_categorical_accuracy: 0.8800 - 185ms/epoch - 6ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Restored model, accuracy: 88.00%\n"
     ]
    }
   ],
   "source": [
    "# Save the weights\n",
    "model.save_weights('./checkpoints/my_checkpoint')\n",
    "\n",
    "# Create a new model instance\n",
    "model = create_model()\n",
    "\n",
    "# Restore the weights\n",
    "model.load_weights('./checkpoints/my_checkpoint')\n",
    "\n",
    "# Evaluate the model\n",
    "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n",
    "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kOGlxPRBEvV1"
   },
   "source": [
    "## Save the entire model\n",
    "\n",
    "Call `tf.keras.Model.save` to save a model's architecture, weights, and training configuration in a single `model.keras` zip archive.\n",
    "\n",
    "An entire model can be saved in three different file formats (the new `.keras` format and two legacy formats: `SavedModel` and `HDF5`). Saving a model as `path/to/model.keras` automatically saves in the latest format.\n",
    "\n",
    "**Note:** For Keras objects, the new high-level `.keras` format is recommended, because it provides richer, name-based saving and reloading that is easier to debug. The low-level SavedModel format and the legacy H5 format continue to be supported for existing code.\n",
    "\n",
    "You can switch to the SavedModel format by:\n",
    "\n",
    "- Passing `save_format='tf'` to `save()`\n",
    "- Passing a filename without an extension\n",
    "\n",
    "You can switch to the H5 format by:\n",
    "\n",
    "- Passing `save_format='h5'` to `save()`\n",
    "- Passing a filename that ends in `.h5`\n",
    "\n",
    "Saving a fully functional model is very useful: you can load it in TensorFlow.js ([Saved Model](https://tensorflow.google.cn/js/tutorials/conversion/import_saved_model), [HDF5](https://tensorflow.google.cn/js/tutorials/conversion/import_keras)) and then train and run it in web browsers, or convert it to run on mobile devices using TensorFlow Lite ([Saved Model](https://tensorflow.google.cn/lite/models/convert/#convert_a_savedmodel_recommended_), [HDF5](https://tensorflow.google.cn/lite/models/convert/#convert_a_keras_model_)).\n",
    "\n",
    "Note: Custom objects (for example, subclassed models or layers) require special attention when saving and loading. Refer to the **Saving custom objects** section below."
   ]
  },
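  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The format-selection rules listed above can be restated as a small lookup. This is only an illustrative sketch: `describe_save_format` is a hypothetical helper, not part of Keras, and checking an explicit `save_format` before the file extension is an assumption of the sketch rather than documented library behavior.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "def describe_save_format(filepath, save_format=None):\n",
    "    # Hypothetical helper that mirrors the rules in the text above.\n",
    "    # Assumption: an explicit save_format is considered first.\n",
    "    if save_format == 'tf':\n",
    "        return 'SavedModel'\n",
    "    if save_format == 'h5':\n",
    "        return 'HDF5'\n",
    "    extension = os.path.splitext(filepath)[1]\n",
    "    if extension == '.keras':\n",
    "        return 'Keras v3 (.keras)'\n",
    "    if extension == '.h5':\n",
    "        return 'HDF5'\n",
    "    if extension == '':\n",
    "        return 'SavedModel'\n",
    "    return 'not covered by the rules above'\n",
    "\n",
    "for path in ['my_model.keras', 'saved_model/my_model', 'my_model.h5']:\n",
    "    print(path, '->', describe_save_format(path))\n",
    "```\n",
    "\n",
    "The three example paths correspond to the three formats discussed in this tutorial."
   ]
  },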
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0fRGnlHMrkI7"
   },
   "source": [
    "### New high-level `.keras` format"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eqO8jj7GsCDn"
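   },
   "source": [
    "The new Keras v3 saving format, marked by the `.keras` extension, is a simpler, more efficient format that implements name-based saving, ensuring that what you load is exactly what you saved from Python's perspective. This makes debugging easier, and it is the recommended format for Keras.\n",
    "\n",
    "The section below illustrates how to save and restore the model in the `.keras` format."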
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:18.453699Z",
     "iopub.status.busy": "2023-11-08T00:04:18.453111Z",
     "iopub.status.idle": "2023-11-08T00:04:19.747619Z",
     "shell.execute_reply": "2023-11-08T00:04:19.746803Z"
    },
    "id": "3f55mAXwukUX"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      " 1/32 [..............................] - ETA: 21s - loss: 2.3923 - sparse_categorical_accuracy: 0.0938"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "21/32 [==================>...........] - ETA: 0s - loss: 1.4819 - sparse_categorical_accuracy: 0.5551 "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "32/32 [==============================] - 1s 3ms/step - loss: 1.2257 - sparse_categorical_accuracy: 0.6420\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      " 1/32 [..............................] - ETA: 0s - loss: 0.4663 - sparse_categorical_accuracy: 0.8750"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "23/32 [====================>.........] - ETA: 0s - loss: 0.4353 - sparse_categorical_accuracy: 0.8845"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "32/32 [==============================] - 0s 2ms/step - loss: 0.4333 - sparse_categorical_accuracy: 0.8820\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3/5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      " 1/32 [..............................] - ETA: 0s - loss: 0.3810 - sparse_categorical_accuracy: 0.8438"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "21/32 [==================>...........] - ETA: 0s - loss: 0.3120 - sparse_categorical_accuracy: 0.9286"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "32/32 [==============================] - 0s 3ms/step - loss: 0.3014 - sparse_categorical_accuracy: 0.9260\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4/5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      " 1/32 [..............................] - ETA: 0s - loss: 0.1449 - sparse_categorical_accuracy: 1.0000"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "23/32 [====================>.........] - ETA: 0s - loss: 0.2208 - sparse_categorical_accuracy: 0.9524"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "32/32 [==============================] - 0s 2ms/step - loss: 0.2073 - sparse_categorical_accuracy: 0.9580\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5/5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      " 1/32 [..............................] - ETA: 0s - loss: 0.1947 - sparse_categorical_accuracy: 1.0000"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "23/32 [====================>.........] - ETA: 0s - loss: 0.1542 - sparse_categorical_accuracy: 0.9606"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "32/32 [==============================] - 0s 2ms/step - loss: 0.1596 - sparse_categorical_accuracy: 0.9620\n"
     ]
    }
   ],
   "source": [
    "# Create and train a new model instance.\n",
    "model = create_model()\n",
    "model.fit(train_images, train_labels, epochs=5)\n",
    "\n",
    "# Save the entire model as a `.keras` zip archive.\n",
    "model.save('my_model.keras')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iHqwaun5g8lD"
   },
   "source": [
    "Reload a fresh Keras model from the `.keras` zip archive:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:19.752206Z",
     "iopub.status.busy": "2023-11-08T00:04:19.751546Z",
     "iopub.status.idle": "2023-11-08T00:04:19.912501Z",
     "shell.execute_reply": "2023-11-08T00:04:19.911729Z"
    },
    "id": "HyfUMOZwux_-"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_5\"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " Layer (type)                Output Shape              Param #   \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=================================================================\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " dense_10 (Dense)            (None, 512)               401920    \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                                 \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " dropout_5 (Dropout)         (None, 512)               0         \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                                 \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " dense_11 (Dense)            (None, 10)                5130      \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                                 \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=================================================================\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total params: 407050 (1.55 MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trainable params: 407050 (1.55 MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Non-trainable params: 0 (0.00 Byte)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "new_model = tf.keras.models.load_model('my_model.keras')\n",
    "\n",
    "# Show the model architecture\n",
    "new_model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9Cn3pSBqvJ5f"
   },
   "source": [
    "Try running evaluation and prediction with the loaded model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:19.920373Z",
     "iopub.status.busy": "2023-11-08T00:04:19.920106Z",
     "iopub.status.idle": "2023-11-08T00:04:20.372924Z",
     "shell.execute_reply": "2023-11-08T00:04:20.372095Z"
    },
    "id": "8BT4mHNIvMdW"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32/32 - 0s - loss: 0.4198 - sparse_categorical_accuracy: 0.8670 - 187ms/epoch - 6ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Restored model, accuracy: 86.70%\n",
      "\r\n",
      " 1/32 [..............................] - ETA: 2s"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "32/32 [==============================] - 0s 1ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1000, 10)\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the restored model\n",
    "loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)\n",
    "print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))\n",
    "\n",
    "print(new_model.predict(test_images).shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kPyhgcoVzqUB"
   },
   "source": [
    "### SavedModel format"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LtcN4VIb7JkK"
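   },
   "source": [
    "The SavedModel format is another way to serialize models. Models saved in this format can be restored using `tf.keras.models.load_model` and are compatible with TensorFlow Serving. The [SavedModel guide](../../guide/saved_model.ipynb) goes into detail about how to serve and inspect a SavedModel. The section below illustrates the steps to save and restore the model."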
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:20.377032Z",
     "iopub.status.busy": "2023-11-08T00:04:20.376433Z",
     "iopub.status.idle": "2023-11-08T00:04:22.350639Z",
     "shell.execute_reply": "2023-11-08T00:04:22.349908Z"
    },
    "id": "sI1YvCDFzpl3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      " 1/32 [..............................] - ETA: 21s - loss: 2.4572 - sparse_categorical_accuracy: 0.0000e+00"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "20/32 [=================>............] - ETA: 0s - loss: 1.4852 - sparse_categorical_accuracy: 0.5594     "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "32/32 [==============================] - 1s 3ms/step - loss: 1.1823 - sparse_categorical_accuracy: 0.6530\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      " 1/32 [..............................] - ETA: 0s - loss: 0.6145 - sparse_categorical_accuracy: 0.8438"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "21/32 [==================>...........] - ETA: 0s - loss: 0.4739 - sparse_categorical_accuracy: 0.8527"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "32/32 [==============================] - 0s 3ms/step - loss: 0.4464 - sparse_categorical_accuracy: 0.8660\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3/5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      " 1/32 [..............................] - ETA: 0s - loss: 0.1838 - sparse_categorical_accuracy: 0.9688"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "22/32 [===================>..........] - ETA: 0s - loss: 0.2747 - sparse_categorical_accuracy: 0.9276"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "32/32 [==============================] - 0s 3ms/step - loss: 0.2765 - sparse_categorical_accuracy: 0.9290\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4/5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      " 1/32 [..............................] - ETA: 0s - loss: 0.1509 - sparse_categorical_accuracy: 0.9688"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "22/32 [===================>..........] - ETA: 0s - loss: 0.2231 - sparse_categorical_accuracy: 0.9474"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "32/32 [==============================] - 0s 3ms/step - loss: 0.2195 - sparse_categorical_accuracy: 0.9490\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5/5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      " 1/32 [..............................] - ETA: 0s - loss: 0.1657 - sparse_categorical_accuracy: 0.9688"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "22/32 [===================>..........] - ETA: 0s - loss: 0.1510 - sparse_categorical_accuracy: 0.9716"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "32/32 [==============================] - 0s 3ms/step - loss: 0.1523 - sparse_categorical_accuracy: 0.9660\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: saved_model/my_model/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: saved_model/my_model/assets\n"
     ]
    }
   ],
   "source": [
    "# Create and train a new model instance.\n",
    "model = create_model()\n",
    "model.fit(train_images, train_labels, epochs=5)\n",
    "\n",
    "# Save the entire model as a SavedModel.\n",
    "!mkdir -p saved_model\n",
    "model.save('saved_model/my_model')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iUvT_3qE8hV5"
   },
   "source": [
    "The SavedModel format is a directory containing a protobuf binary and a TensorFlow checkpoint. Inspect the saved model directory:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:22.354744Z",
     "iopub.status.busy": "2023-11-08T00:04:22.354450Z",
     "iopub.status.idle": "2023-11-08T00:04:22.669329Z",
     "shell.execute_reply": "2023-11-08T00:04:22.668282Z"
    },
    "id": "sq8fPglI1RWA"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "my_model\r\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "assets\tfingerprint.pb\tkeras_metadata.pb  saved_model.pb  variables\r\n"
     ]
    }
   ],
   "source": [
    "# my_model directory\n",
    "!ls saved_model\n",
    "\n",
    "# Contains an assets folder, saved_model.pb, and variables folder.\n",
    "!ls saved_model/my_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "B7qfpvpY9HCe"
   },
   "source": [
    "Reload a fresh Keras model from the saved model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:22.674215Z",
     "iopub.status.busy": "2023-11-08T00:04:22.673880Z",
     "iopub.status.idle": "2023-11-08T00:04:23.137929Z",
     "shell.execute_reply": "2023-11-08T00:04:23.137208Z"
    },
    "id": "0YofwHdN0pxa"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n",
      "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n",
      "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n",
      "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4\n",
      "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.5\n",
      "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.6\n",
      "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.7\n",
      "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.8\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_6\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " dense_12 (Dense)            (None, 512)               401920    \n",
      "                                                                 \n",
      " dropout_6 (Dropout)         (None, 512)               0         \n",
      "                                                                 \n",
      " dense_13 (Dense)            (None, 10)                5130      \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 407050 (1.55 MB)\n",
      "Trainable params: 407050 (1.55 MB)\n",
      "Non-trainable params: 0 (0.00 Byte)\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "new_model = tf.keras.models.load_model('saved_model/my_model')\n",
    "\n",
    "# Check its architecture\n",
    "new_model.summary()"
   ]
  },
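  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uWwgNaz19TH2"
   },
   "source": [
    "The reloaded model is compiled using the same arguments as the original model. Try running evaluation and prediction with the loaded model:"
   ]
  },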
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:23.145460Z",
     "iopub.status.busy": "2023-11-08T00:04:23.145199Z",
     "iopub.status.idle": "2023-11-08T00:04:23.569906Z",
     "shell.execute_reply": "2023-11-08T00:04:23.569154Z"
    },
    "id": "Yh5Mu0yOgE5J"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32/32 - 0s - loss: 0.4436 - sparse_categorical_accuracy: 0.8560 - 189ms/epoch - 6ms/step\n",
      "Restored model, accuracy: 85.60%\n",
      "32/32 [==============================] - 0s 1ms/step\n",
      "(1000, 10)\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the restored model\n",
    "loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)\n",
    "print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))\n",
    "\n",
    "print(new_model.predict(test_images).shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SkGwf-50zLNn"
   },
   "source": [
    "### HDF5 format\n",
    "\n",
    "Keras provides a basic legacy high-level save format using the [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) standard."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:23.574213Z",
     "iopub.status.busy": "2023-11-08T00:04:23.573585Z",
     "iopub.status.idle": "2023-11-08T00:04:24.845421Z",
     "shell.execute_reply": "2023-11-08T00:04:24.844458Z"
    },
    "id": "m2dkmJVCGUia"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "32/32 [==============================] - 1s 3ms/step - loss: 1.1044 - sparse_categorical_accuracy: 0.6980\n",
      "Epoch 2/5\n",
      "32/32 [==============================] - 0s 3ms/step - loss: 0.3960 - sparse_categorical_accuracy: 0.8830\n",
      "Epoch 3/5\n",
      "32/32 [==============================] - 0s 2ms/step - loss: 0.2616 - sparse_categorical_accuracy: 0.9280\n",
      "Epoch 4/5\n",
      "32/32 [==============================] - 0s 3ms/step - loss: 0.1953 - sparse_categorical_accuracy: 0.9520\n",
      "Epoch 5/5\n",
      "32/32 [==============================] - 0s 3ms/step - loss: 0.1439 - sparse_categorical_accuracy: 0.9670\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/engine/training.py:3103: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')`.\n",
      "  saving_api.save_model(\n"
     ]
    }
   ],
   "source": [
    "# Create and train a new model instance.\n",
    "model = create_model()\n",
    "model.fit(train_images, train_labels, epochs=5)\n",
    "\n",
    "# Save the entire model to an HDF5 file.\n",
    "# The '.h5' extension indicates that the model should be saved to HDF5.\n",
    "model.save('my_model.h5')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GWmttMOqS68S"
   },
   "source": [
    "Now, recreate the model from that file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:24.849451Z",
     "iopub.status.busy": "2023-11-08T00:04:24.848669Z",
     "iopub.status.idle": "2023-11-08T00:04:24.927225Z",
     "shell.execute_reply": "2023-11-08T00:04:24.926564Z"
    },
    "id": "5NDMO_7kS6Do"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_7\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " dense_14 (Dense)            (None, 512)               401920    \n",
      "                                                                 \n",
      " dropout_7 (Dropout)         (None, 512)               0         \n",
      "                                                                 \n",
      " dense_15 (Dense)            (None, 10)                5130      \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 407050 (1.55 MB)\n",
      "Trainable params: 407050 (1.55 MB)\n",
      "Non-trainable params: 0 (0.00 Byte)\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Recreate the exact same model, including its weights and the optimizer\n",
    "new_model = tf.keras.models.load_model('my_model.h5')\n",
    "\n",
    "# Show the model architecture\n",
    "new_model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JXQpbTicTBwt"
   },
   "source": [
    "Check its accuracy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:04:24.934763Z",
     "iopub.status.busy": "2023-11-08T00:04:24.934505Z",
     "iopub.status.idle": "2023-11-08T00:04:25.179642Z",
     "shell.execute_reply": "2023-11-08T00:04:25.178886Z"
    },
    "id": "jwEaj9DnTCVA"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32/32 - 0s - loss: 0.4303 - sparse_categorical_accuracy: 0.8610 - 191ms/epoch - 6ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Restored model, accuracy: 86.10%\n"
     ]
    }
   ],
   "source": [
    "loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)\n",
    "print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))"
   ]
  },
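  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dGXqd4wWJl8O"
   },
   "source": [
    "Keras saves models by inspecting their architectures. This technique saves everything:\n",
    "\n",
    "- The weight values\n",
    "- The model's architecture\n",
    "- The model's training configuration (what you pass to the `.compile()` method)\n",
    "- The optimizer and its state, if any (this lets you restart training where you left off)\n",
    "\n",
    "Keras cannot save `v1.x` optimizers (from `tf.compat.v1.train`) because they are not compatible with checkpoints. For v1.x optimizers, you need to recompile the model after loading, which loses the optimizer's state.\n"
   ]
  },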
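  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kAUKJQyGqTNH"
   },
   "source": [
    "### Saving custom objects\n",
    "\n",
    "If you are using the SavedModel format, you can skip this section. The key difference between the high-level `.keras`/HDF5 formats and the low-level SavedModel format is that the `.keras`/HDF5 formats use object configs to save the model architecture, while SavedModel saves the execution graph. SavedModels can therefore save custom objects, such as subclassed models and custom layers, without requiring the original code. However, this can make low-level SavedModels harder to debug, and we recommend the high-level `.keras` format instead because it is name-based and native to Keras.\n",
    "\n",
    "To save custom objects to `.keras` and HDF5, you must do the following:\n",
    "\n",
    "1. Define a `get_config` method in your object, and optionally a `from_config` classmethod.\n",
    "    - `get_config(self)` returns a JSON-serializable dictionary of the parameters needed to recreate the object.\n",
    "    - `from_config(cls, config)` uses the config returned by `get_config` to create a new object. By default, this function uses the config as initialization kwargs (`return cls(**config)`).\n",
    "2. Pass the custom objects to the model in one of three ways:\n",
    "    - Register the custom object with the `@tf.keras.utils.register_keras_serializable` decorator. **(recommended)**\n",
    "    - Pass the object directly to the `custom_objects` argument when loading the model. The argument must be a dictionary mapping the string class name to the Python class, for example `tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})`.\n",
    "    - Use a `tf.keras.utils.custom_object_scope` with the object included in the `custom_objects` dictionary argument, and place a `tf.keras.models.load_model(path)` call within the scope.\n",
    "\n",
    "See the [Writing layers and models from scratch](https://tensorflow.google.cn/guide/keras/custom_layers_and_models) tutorial for examples of custom objects and `get_config`.\n"
   ]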
  }
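  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two steps above can be sketched as follows. This is a minimal illustration rather than code from the original tutorial: `ScaleLayer`, its `factor` argument, the `Example` package name, and the `custom_model.keras` file name are all hypothetical names introduced here. It also assumes a TensorFlow version in which `model.save` supports the `.keras` format. The layer relies on the default `from_config`, which passes the saved config to `__init__` as keyword arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# Hypothetical custom layer, registered so that Keras can look it up by name on load.\n",
    "@tf.keras.utils.register_keras_serializable(package='Example')\n",
    "class ScaleLayer(tf.keras.layers.Layer):\n",
    "  def __init__(self, factor=2.0, **kwargs):\n",
    "    super().__init__(**kwargs)\n",
    "    self.factor = factor\n",
    "\n",
    "  def call(self, inputs):\n",
    "    return inputs * self.factor\n",
    "\n",
    "  def get_config(self):\n",
    "    # Extend the base layer config with the arguments that __init__ needs.\n",
    "    config = super().get_config()\n",
    "    config.update({'factor': self.factor})\n",
    "    return config\n",
    "\n",
    "custom_model = tf.keras.Sequential([\n",
    "    tf.keras.Input(shape=(3,)),\n",
    "    tf.keras.layers.Dense(4),\n",
    "    ScaleLayer(factor=0.5),\n",
    "])\n",
    "\n",
    "# Save to the high-level .keras format, then load the model back.\n",
    "custom_model.save('custom_model.keras')\n",
    "restored_model = tf.keras.models.load_model('custom_model.keras')\n",
    "\n",
    "# The restored layer is rebuilt from its saved config.\n",
    "print(restored_model.layers[-1].get_config()['factor'])"
   ]
  }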
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "save_and_load.ipynb",
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
